/[finished]Assignment_2_word2vec/sgd.py

https://github.com/flypythoncom/cs224n_2019
#!/usr/bin/env python

import pickle
import glob
import random
import numpy as np
import os.path as op

# Save parameters every few SGD iterations as a fail-safe
SAVE_PARAMS_EVERY = 5000
def load_saved_params():
    """
    A helper function that loads previously saved parameters and
    restores the iteration count to where the last run left off.
    """
    st = 0
    for f in glob.glob("saved_params_*.npy"):
        saved_iter = int(op.splitext(op.basename(f))[0].split("_")[2])
        if saved_iter > st:
            st = saved_iter

    if st > 0:
        params_file = "saved_params_%d.npy" % st
        state_file = "saved_state_%d.pickle" % st
        params = np.load(params_file)
        with open(state_file, "rb") as f:
            state = pickle.load(f)
        return st, params, state
    else:
        return st, None, None
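
# Filename parsing, worked through (illustrative example, not part of the
# scaffold): for a checkpoint file "saved_params_5000.npy",
# op.splitext(op.basename(f))[0] gives "saved_params_5000", and
# .split("_")[2] extracts "5000", so the loop above recovers the latest
# saved iteration number.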
def save_params(iteration, params):
    params_file = "saved_params_%d.npy" % iteration
    np.save(params_file, params)
    with open("saved_state_%d.pickle" % iteration, "wb") as f:
        pickle.dump(random.getstate(), f)
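
# Checkpoint round trip, sketched (assumes write access to the working
# directory; not part of the original scaffold):
#   save_params(100, np.zeros(5))             # writes saved_params_100.npy
#                                             # and saved_state_100.pickle
#   st, params, state = load_saved_params()   # -> st == 100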
def sgd(f, x0, step, iterations, postprocessing=None, useSaved=False,
        PRINT_EVERY=10):
    """ Stochastic Gradient Descent

    Implement the stochastic gradient descent method in this function.

    Arguments:
    f -- the function to optimize; it should take a single
         argument and yield two outputs: a loss and the gradient
         with respect to that argument
    x0 -- the initial point to start SGD from
    step -- the step size for SGD
    iterations -- total iterations to run SGD for
    postprocessing -- postprocessing function for the parameters
         if necessary. In the case of word2vec we will need to
         normalize the word vectors to have unit length.
    PRINT_EVERY -- specifies how often, in iterations, to print the loss

    Return:
    x -- the parameter value after SGD finishes
    """

    # Anneal learning rate every several iterations
    ANNEAL_EVERY = 20000

    if useSaved:
        start_iter, oldx, state = load_saved_params()
        if start_iter > 0:
            x0 = oldx
            # Replay the annealing schedule up to the checkpoint; integer
            # division keeps the exponent in whole annealing periods, e.g.
            # resuming at start_iter = 40000 scales the step by
            # 0.5 ** (40000 // 20000) = 0.25.
            step *= 0.5 ** (start_iter // ANNEAL_EVERY)
        if state:
            random.setstate(state)
    else:
        start_iter = 0

    x = x0

    if not postprocessing:
        postprocessing = lambda x: x

    exploss = None

    for iteration in range(start_iter + 1, iterations + 1):
        # You might want to print the progress every few iterations.
        loss = None
        ### YOUR CODE HERE
        loss, gd = f(x)
        x = x - step * gd
        ### END YOUR CODE

        x = postprocessing(x)
        if iteration % PRINT_EVERY == 0:
            # Track an exponential moving average of the loss so the
            # printed value is smoother than the per-iteration loss.
            if exploss is None:
                exploss = loss
            else:
                exploss = .95 * exploss + .05 * loss
            print("iter %d: %f" % (iteration, exploss))

        if iteration % SAVE_PARAMS_EVERY == 0 and useSaved:
            save_params(iteration, x)

        if iteration % ANNEAL_EVERY == 0:
            step *= 0.5

    return x
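
# Usage sketch (illustrative; loss_and_grad and vecs0 are hypothetical
# stand-ins for a word2vec loss wrapper and the initial vector matrix).
# Per the docstring, word vectors are renormalized to unit length after
# each update via the postprocessing hook:
#   normalize = lambda m: m / np.linalg.norm(m, axis=1, keepdims=True)
#   vecs = sgd(loss_and_grad, vecs0, 0.3, 40000,
#              postprocessing=normalize, useSaved=True, PRINT_EVERY=10)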
def sanity_check():
    quad = lambda x: (np.sum(x ** 2), x * 2)

    print("Running sanity checks...")
    t1 = sgd(quad, 0.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 1 result:", t1)
    assert abs(t1) <= 1e-6

    t2 = sgd(quad, 0.0, 0.01, 1000, PRINT_EVERY=100)
    print("test 2 result:", t2)
    assert abs(t2) <= 1e-6

    t3 = sgd(quad, -1.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 3 result:", t3)
    assert abs(t3) <= 1e-6

    print("-" * 40)
    print("ALL TESTS PASSED")
    print("-" * 40)


if __name__ == "__main__":
    sanity_check()